{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "GEN_7_Attention.ipynb",
      "provenance": [],
      "authorship_tag": "ABX9TyMFvVtE/VzSgx0gyQROiPvT",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/cxbxmxcx/GenReality/blob/master/GEN_7_Attention.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZeTjyZSarOGt",
        "outputId": "0b7db2ee-136a-4055-d83f-170516f02fbc"
      },
      "source": [
        "#@title Imports and Input tensor x\r\n",
        "import torch\r\n",
        "import numpy as np\r\n",
        "\r\n",
        "# Seed NumPy's global RNG so a Restart & Run All draws the same integers\r\n",
        "# here and in every later cell that calls np.random.\r\n",
        "np.random.seed(42)\r\n",
        "\r\n",
        "# x is a 3x4 tensor of integers drawn from {0, 1, 2}, stored as float32 so\r\n",
        "# it can be matrix-multiplied with the float32 weight tensors.\r\n",
        "x = torch.tensor(np.random.randint(0, 3, (3, 4)), dtype=torch.float32)\r\n",
        "print(x)"
      ],
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[1., 1., 0., 1.],\n",
            "        [1., 0., 1., 2.],\n",
            "        [2., 0., 1., 1.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vQx7HPtDqjGG",
        "outputId": "50f9f0a0-cdc9-4151-ba1e-0a8f73311b62"
      },
      "source": [
        "#@title Create random key, query and value weights\r\n",
        "# One random integer 4x3 matrix per projection, converted to float32.\r\n",
        "# Draw order is key, query, value; key/query entries come from {0, 1}\r\n",
        "# and value entries from {0, 1, 2, 3}.\r\n",
        "w_key = torch.tensor(np.random.randint(0, 2, (4, 3)), dtype=torch.float32)\r\n",
        "w_query = torch.tensor(np.random.randint(0, 2, (4, 3)), dtype=torch.float32)\r\n",
        "w_value = torch.tensor(np.random.randint(0, 4, (4, 3)), dtype=torch.float32)\r\n",
        "\r\n",
        "# A single print with a newline separator emits the three tensors in order.\r\n",
        "print(w_key, w_query, w_value, sep='\\n')"
      ],
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[0., 0., 0.],\n",
            "        [1., 0., 1.],\n",
            "        [1., 0., 0.],\n",
            "        [0., 0., 0.]])\n",
            "tensor([[1., 0., 0.],\n",
            "        [0., 0., 0.],\n",
            "        [0., 1., 1.],\n",
            "        [0., 1., 1.]])\n",
            "tensor([[2., 2., 0.],\n",
            "        [0., 0., 3.],\n",
            "        [2., 1., 0.],\n",
            "        [2., 2., 1.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_k8EOig2pP2t",
        "outputId": "a16da8b5-82c9-487b-a0db-efbea3fbf462"
      },
      "source": [
        "#@title Multiply by x\r\n",
        "# Project the input through each weight matrix: (3, 4) @ (4, 3) -> (3, 3).\r\n",
        "keys = torch.matmul(x, w_key)\r\n",
        "querys = torch.matmul(x, w_query)\r\n",
        "values = torch.matmul(x, w_value)\r\n",
        "\r\n",
        "# Show the three projections in key, query, value order.\r\n",
        "print(keys, querys, values, sep='\\n')"
      ],
      "execution_count": 41,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[1., 0., 1.],\n",
            "        [1., 0., 0.],\n",
            "        [1., 0., 0.]])\n",
            "tensor([[1., 1., 1.],\n",
            "        [1., 3., 3.],\n",
            "        [2., 2., 2.]])\n",
            "tensor([[4., 4., 4.],\n",
            "        [8., 7., 2.],\n",
            "        [8., 7., 1.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mN60VqFgpeTm",
        "outputId": "4c6714a3-6e70-486e-abc9-a00acbd9351a"
      },
      "source": [
        "#@title Calculate Attention Score\r\n",
        "# attn_scores[i, j] is the dot product of query row i with key row j.\r\n",
        "attn_scores = torch.matmul(querys, keys.t())\r\n",
        "print(attn_scores)"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[2., 1., 1.],\n",
            "        [4., 1., 1.],\n",
            "        [4., 2., 2.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3ThH_XYrp-xt",
        "outputId": "cf6def98-a398-4ef8-a903-ceef39b0fead"
      },
      "source": [
        "#@title Apply Softmax\r\n",
        "# Normalise each row of scores (dim=-1) into weights that sum to 1.\r\n",
        "# torch.softmax keeps this cell free of its own import: it relies only on\r\n",
        "# the torch import in the notebook's first code cell.\r\n",
        "attn_scores_softmax = torch.softmax(attn_scores, dim=-1)\r\n",
        "print(attn_scores_softmax)"
      ],
      "execution_count": 43,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[0.5761, 0.2119, 0.2119],\n",
            "        [0.9094, 0.0453, 0.0453],\n",
            "        [0.7870, 0.1065, 0.1065]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Gdi6SzEBqS0a",
        "outputId": "0d2a5063-1a21-48d5-fbaa-a3eb9fadaf66"
      },
      "source": [
        "#@title Calculate Weighted Values and Outputs\r\n",
        "# weighted_values[j, i, :] = attn_scores_softmax[i, j] * values[j, :],\r\n",
        "# built by broadcasting shapes (3, 3, 1) and (3, 1, 3).\r\n",
        "weighted_values = attn_scores_softmax.t().unsqueeze(-1) * values.unsqueeze(1)\r\n",
        "\r\n",
        "# Summing over the leading axis gives\r\n",
        "# outputs[i, :] = sum_j attn_scores_softmax[i, j] * values[j, :].\r\n",
        "outputs = torch.sum(weighted_values, dim=0)\r\n",
        "print(outputs)"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[5.6955, 5.2716, 2.9403],\n",
            "        [4.3622, 4.2717, 3.7736],\n",
            "        [4.8521, 4.6390, 3.4675]])\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}